#!/usr/bin/env python3
# Author: Armit
# Create Time: 2022/11/18 

# preprocess the ARIP csv data: select features, optionally de-duplicate rows, compress dtypes and save

from pathlib import Path
import pickle as pkl
from time import time
from datetime import datetime

import pandas as pd
from pandas_profiling import ProfileReport

from data import *
from plot_arip import plot_stats

# name of the raw csv file that preprocess() reads from the RDATA_PATH directory
DATA_FILE: str = '22D_UR.csv'


def df_memory_usage(df:pd.DataFrame) -> float:
  '''
  Return the deep memory footprint of `df` (index + every column, including
  the contents of object columns) in units of 2**20 bytes (labelled "MB" in the logs).
  '''
  n_bytes = df.memory_usage(deep=True).sum()
  return n_bytes / (1 << 20)


def _ask_dedup(df:pd.DataFrame, log) -> pd.DataFrame:
  '''
  Offer to drop duplicated rows of `df`; return the frame to continue with.

  If `df` has no duplicated rows it is returned unchanged without asking.
  Otherwise the user is asked on stdin until 'y' or 'n' is entered; end of
  input (e.g. a non-interactive run) counts as 'n'. `log` is the logging
  callback of `preprocess`. The de-duplicated copy is only kept alive if chosen.
  '''
  df_dedup = df.drop_duplicates()
  len_df, len_dfd = len(df), len(df_dedup)
  if len_df == len_dfd: return df     # also covers the empty frame (no 0-division below)

  log(f'<< duplicated lines detected {len_df} => {len_dfd} ({1 - len_dfd / len_df:.3%})')
  opt = ''
  while opt not in ('y', 'n'):
    try:
      opt = input('>> remove duplicates? (enter y or n): ').strip().lower()
    except EOFError:      # no interactive stdin: keep the data as-is
      opt = 'n'

  if opt == 'y':
    log(f'>> de-duplicated len(df): {len(df_dedup)}')
    return df_dedup
  return df


def _convert_for_storage(df:pd.DataFrame, columns:list, log) -> dict:
  '''
  Convert the given `columns` of `df` IN PLACE to their storage dtypes.

  - columns in FEATURE_NUM are cast to DTYPE_STORAGE['num']; mean/std and the
    average absolute precision loss w.r.t. DTYPE_PROCESS['num'] are logged
  - columns in FEATURE_CAT are replaced by the position of each value in the
    sorted list of distinct values, cast to DTYPE_STORAGE['cat']
  - any other column is logged and stops in the debugger (breakpoint)

  Returns the category lists of the categorical columns:
    {'feat_name': ['cat1', 'cat2', ...]}
  '''
  cat_dict = {}
  for ft in columns:
    if ft in FEATURE_NUM:
      v_h = df[ft].astype(DTYPE_PROCESS['num'])
      v_l = df[ft].astype(DTYPE_STORAGE['num'])
      avg_d = (v_h - v_l.astype(DTYPE_PROCESS['num'])).abs().mean()
      avg_v = v_h.abs().mean()
      # avg_v == 0 means every value is 0, which is stored exactly -> 0% loss (avoids 0/0 = nan)
      rel_d = avg_d / avg_v if avg_v else 0.0
      log(f'  {ft} is numerical: avg={v_h.mean()}, std={v_h.std()}; precision loss: {avg_d} ({rel_d:.3%})')
      df[ft] = v_l
    elif ft in FEATURE_CAT:
      cat_dict[ft] = cats = sorted(set(df[ft].to_list()))
      mapping = { c: i for i, c in enumerate(cats) }
      log(f'  {ft} is categorical: ord={len(cats)}')
      # vectorized dict lookup instead of a per-row Python lambda; every value is a key of `mapping`
      df[ft] = df[ft].map(mapping).astype(DTYPE_STORAGE['cat'])
    else:
      log(f'<< unknown column: {ft}')
      # NOTE(review): leftover debug stop for a column listed in FEATURE_ALL
      # but in neither FEATURE_NUM nor FEATURE_CAT
      breakpoint()
  return cat_dict


def preprocess():
  '''
  Run the ARIP preprocessing pipeline on `RDATA_PATH / DATA_FILE`.

  Steps: load the csv -> keep only the FEATURE_ALL columns -> optionally drop
  duplicated rows (asks on stdin) -> write a minimal ProfileReport to
  ARIP_REPORT_FILE (only if that file does not exist yet) -> convert columns
  to storage dtypes -> save the frame via `save_df` and the category lists
  via `CatDict.save`.

  Progress and statistics are printed and mirrored into ARIP_STAT_FILE,
  which is overwritten on every run.
  '''
  # perfcount
  ts = time()

  # log stats; `with` closes the log file even if any step below raises
  with open(ARIP_STAT_FILE, 'w', encoding='utf-8') as logfp:
    def log(s:str):
      '''Append `s` to the stat file (flushed immediately) and echo it to stdout.'''
      logfp.write(str(s))
      logfp.write('\n')
      logfp.flush()
      print(s)

    log(f'>> now ts: {datetime.now()!s}')

    # load data
    df = pd.read_csv(Path(RDATA_PATH) / DATA_FILE)
    log(f'>> len(df): {len(df)}')
    log(f'>> df.columns({len(df.columns)}): {list(df.columns)}')
    log(f'>> df.memory_usage: {df_memory_usage(df):.3f} MB')
    print(df.head(n=5))

    # filter & sort columns
    df = df[FEATURE_ALL]
    columns = list(df.columns)
    log(f'>> filtered df.columns({len(columns)}): {columns}')

    # ask for de-duplicate
    df = _ask_dedup(df, log)

    # make report (skipped when the report file already exists)
    fp = ARIP_REPORT_FILE
    if not fp.exists():
      ProfileReport(df, minimal=True).to_file(fp)

    # dtype convert for storage
    cat_dict = _convert_for_storage(df, columns, log)
    log(f'>> reduced df.memory_usage: {df_memory_usage(df):.3f} MB')

    # save data
    sz = save_df(df)
    log(f'>> file size: {sz:.3f} MB')

    CatDict.save(cat_dict)

    # perfcount
    log(f'>> done in {time() - ts:.3f}s')


if __name__ == '__main__':
  # script entry: run preprocess() first, then plot_stats() from plot_arip
  for step in (preprocess, plot_stats):
    step()
